#!/usr/bin/env python
# -*- coding: utf-8 -*-

from __future__ import absolute_import
from __future__ import print_function
from __future__ import division
import warnings

import tensorflow as tf
from tensorflow.python.client.session import (
    register_session_run_conversion_functions)
import six

from zhusuan import distributions
from zhusuan.utils import TensorArithmeticMixin
from zhusuan.framework.meta_bn import Local, MetaBayesianNet
from zhusuan.framework.utils import Context


__all__ = [
    'StochasticTensor',
    'BayesianNet',
]


class StochasticTensor(TensorArithmeticMixin):
    """
    The :class:`StochasticTensor` class represents the stochastic nodes in a
    :class:`BayesianNet`.

    We can use any distribution available in :mod:`zhusuan.distributions` to
    construct a stochastic node in a :class:`BayesianNet`. For example::

        bn = zs.BayesianNet()
        x = bn.normal("x", 0., std=1.)

    will build a stochastic node in ``bn`` with the
    :class:`~zhusuan.distributions.univariate.Normal` distribution. The
    returned ``x`` will be a :class:`StochasticTensor`. The second line is
    equivalent to::

        dist = zs.distributions.Normal(0., std=1.)
        x = bn.stochastic("x", dist)

    :class:`StochasticTensor` instances are Tensor-like, which means that
    they can be passed into any Tensorflow operations. This makes it easy
    to build Bayesian networks by mixing stochastic nodes and Tensorflow
    primitives.

    .. seealso::

        For more information, please refer to :doc:`/tutorials/concepts`.

    :param bn: A :class:`BayesianNet`.
    :param name: A string. The name of the :class:`StochasticTensor`. Must be
        unique in a :class:`BayesianNet`.
    :param dist: A :class:`~zhusuan.distributions.base.Distribution`
        instance that determines the distribution used in this stochastic node.
    :param observation: A Tensor, which matches the shape of `dist`. If
        specified, then the :class:`StochasticTensor` is observed and
        the :attr:`tensor` property will return the `observation`. This
        argument will overwrite the observation provided in
        :meth:`zhusuan.framework.meta_bn.MetaBayesianNet.observe`.
    :param n_samples: A 0-D `int32` Tensor. Number of samples generated by
        this :class:`StochasticTensor`.
    """

    def __init__(self, bn, name, dist, observation=None, **kwargs):
        if bn is None:
            # Old-style construction: no BayesianNet was given explicitly, so
            # try to register this node in the active BayesianNet context.
            warnings.warn(
                "The old-style StochasticTensor wrappers will be removed "
                "in a future version. Please see tutorials/concepts.rst for "
                "the suggested way of model construction.",
                FutureWarning)
            try:
                bn = BayesianNet.get_context()
            except RuntimeError:
                # Presumably no BayesianNet context is active; the node then
                # stays detached (bn is None).
                pass
            else:
                bn.nodes[name] = self

        self._bn = bn
        self._name = name
        self._dist = dist
        self._dtype = dist.dtype
        # Only `n_samples` is read from kwargs; other keys are ignored.
        self._n_samples = kwargs.get("n_samples", None)
        # An explicit `observation` takes precedence over the `observed`
        # dict stored on the BayesianNet.
        if observation is not None:
            self._observation = self._check_observation(observation)
        elif (self._bn is not None) and (self._name in self._bn._observed):
            self._observation = self._check_observation(
                self._bn._observed[name])
        else:
            self._observation = None
        super(StochasticTensor, self).__init__()

    def _check_observation(self, observation):
        """
        Convert `observation` to a Tensor of this node's dtype and check
        that its static shape can be broadcast against the distribution's
        batch shape concatenated with its value shape.

        :param observation: A Tensor or a Tensor-convertible value.
        :return: The converted Tensor.
        :raises TypeError, ValueError: If `observation` cannot be converted
            to this node's dtype (the exception type raised by Tensorflow is
            preserved), or, as ValueError, if the static shapes are
            incompatible. The original exception is attached as the cause.
        """
        type_msg = "Incompatible types of {}('{}') and its observation: {}"
        try:
            observation = tf.convert_to_tensor(observation, dtype=self._dtype)
        except (TypeError, ValueError) as e:
            # Tensorflow can report a dtype mismatch as either TypeError or
            # ValueError (e.g. TypeError for Python literals of the wrong
            # kind); re-raise the same type with a message naming this node.
            six.raise_from(
                type(e)(type_msg.format(
                    self.__class__.__name__, self._name, e)),
                e)

        shape_msg = "Incompatible shapes of {}('{}') and its observation: " \
                    "{} vs {}."
        dist_shape = self._dist.get_batch_shape().concatenate(
            self._dist.get_value_shape())
        try:
            tf.broadcast_static_shape(dist_shape, observation.get_shape())
        except ValueError as e:
            six.raise_from(
                type(e)(shape_msg.format(
                    self.__class__.__name__, self._name, dist_shape,
                    observation.get_shape())),
                e)
        return observation

    @property
    def bn(self):
        """
        The :class:`BayesianNet` where the :class:`StochasticTensor` lives.

        :return: A :class:`BayesianNet` instance.
        """
        return self._bn

    # TODO: __str__, __repr__ for StochasticTensor

    @property
    def name(self):
        """
        The name of the :class:`StochasticTensor`.

        :return: A string.
        """
        return self._name

    @property
    def dtype(self):
        """
        The sample type of the :class:`StochasticTensor`.

        :return: A ``DType`` instance.
        """
        return self._dtype

    @property
    def dist(self):
        """
        The distribution followed by the :class:`StochasticTensor`.

        :return: A :class:`~zhusuan.distributions.base.Distribution` instance.
        """
        return self._dist

    def is_observed(self):
        """
        Whether the :class:`StochasticTensor` is observed or not.

        :return: A bool.
        """
        return self._observation is not None

    @property
    def tensor(self):
        """
        The value of this :class:`StochasticTensor`. If it is observed, then
        the observation is returned, otherwise samples are returned.

        Samples are drawn once, on first access, and cached afterwards.

        :return: A Tensor.
        """
        if self._observation is not None:
            return self._observation
        elif not hasattr(self, "_samples"):
            self._samples = self._dist.sample(n_samples=self._n_samples)
        return self._samples

    @property
    def shape(self):
        """
        Return the static shape of this :class:`StochasticTensor`.

        :return: A ``TensorShape`` instance.
        """
        return self.tensor.shape

    def get_shape(self):
        """
        Alias of :attr:`shape`.

        :return: A ``TensorShape`` instance.
        """
        return self.shape

    @property
    def cond_log_p(self):
        """
        The conditional log probability of the :class:`StochasticTensor`,
        evaluated at its current value (given by :attr:`tensor`).

        The result is computed once and cached.

        :return: A Tensor.
        """
        if not hasattr(self, "_cond_log_p"):
            self._cond_log_p = self._dist.log_prob(self.tensor)
        return self._cond_log_p

    @staticmethod
    def _to_tensor(value, dtype=None, name=None, as_ref=False):
        """
        Tensor conversion function for :class:`StochasticTensor`: returns
        ``value.tensor``.

        :raises ValueError: If a dtype incompatible with ``value.dtype`` is
            requested, or if a reference is requested (`as_ref`).
        """
        if dtype and not dtype.is_compatible_with(value.dtype):
            raise ValueError("Incompatible type conversion requested to type "
                             "'{}' for variable of type '{}'".
                             format(dtype.name, value.dtype.name))
        if as_ref:
            raise ValueError("{}: Ref type not supported.".format(value))
        return value.tensor

    # Below are deprecated features:

    @property
    def net(self):
        """
        .. warning::

            Deprecated in 0.4, will be removed in 0.4.1.

        The :class:`BayesianNet` where the :class:`StochasticTensor` lives.

        :return: A :class:`BayesianNet` instance.
        """
        warnings.warn(
            "StochasticTensor: The `.net` property will be removed in the "
            "coming version (0.4.1), use `.bn` instead.",
            FutureWarning)
        return self._bn

    @property
    def distribution(self):
        """
        .. warning::

            Deprecated in 0.4, will be removed in 0.4.1.

        The distribution followed by the :class:`StochasticTensor`.

        :return: A :class:`~zhusuan.distributions.base.Distribution` instance.
        """
        warnings.warn(
            "StochasticTensor: The `.distribution` property will be removed "
            "in the coming version (0.4.1), use `.dist` instead.",
            FutureWarning)
        return self._dist

    def sample(self, n_samples):
        """
        .. warning::

            Deprecated in 0.4, will be removed in 0.4.1.

        Sample from the underlying distribution.

        :param n_samples: A 0-D `int32` Tensor. The number of samples.
        :return: A Tensor.
        """
        warnings.warn(
            "StochasticTensor: The `sample()` method will be removed "
            "in the coming version (0.4.1), use `.dist.sample()` instead.",
            FutureWarning)
        return self._dist.sample(n_samples)

    def log_prob(self, given):
        """
        .. warning::

            Deprecated in 0.4, will be removed in 0.4.1.

        Compute the log probability density (mass) function of
        the underlying distribution at the `given` value.

        :param given: A Tensor.
        :return: A Tensor. The log probability value.
        """
        warnings.warn(
            "StochasticTensor: The `log_prob()` method will be removed "
            "in the coming version (0.4.1), use `.dist.log_prob()` instead.",
            FutureWarning)
        return self._dist.log_prob(given)

    def prob(self, given):
        """
        .. warning::

            Deprecated in 0.4, will be removed in 0.4.1.

        Compute the probability density (mass) function of
        the underlying distribution at the `given` value.

        :param given: A Tensor.
        :return: A Tensor. The probability value.
        """
        warnings.warn(
            "StochasticTensor: The `prob()` method will be removed "
            "in the coming version (0.4.1), use `.dist.prob()` instead.",
            FutureWarning)
        return self._dist.prob(given)


def _fetch_stochastic_tensor(st):
    # Fetch the underlying Tensor; the contraction function unwraps the
    # single fetched value.
    return [st.tensor], lambda fetched_values: fetched_values[0]


def _feed_stochastic_tensor(st, feed_value):
    # Feed `feed_value` into the underlying Tensor.
    return [(st.tensor, feed_value)]


def _feed_stochastic_tensor_for_partial_run(st):
    # The underlying Tensor is the feedable target in partial runs.
    return [st.tensor]


# Let Tensorflow operations accept StochasticTensor arguments directly.
tf.register_tensor_conversion_function(
    StochasticTensor, StochasticTensor._to_tensor)

# bring support for session.run(StochasticTensor), and for using as keys
# in feed_dict.
register_session_run_conversion_functions(
    StochasticTensor,
    fetch_function=_fetch_stochastic_tensor,
    feed_function=_feed_stochastic_tensor,
    feed_function_for_partial_run=_feed_stochastic_tensor_for_partial_run,
)


class _BayesianNet(object):
    """
    Core implementation behind :class:`BayesianNet`: keeps a dict of named
    nodes, creates stochastic and deterministic nodes, and computes
    conditional and joint log probabilities. When created inside a
    :class:`~zhusuan.framework.meta_bn.Local` context, observations and the
    associated meta Bayesian net are read from that context.
    """

    def __init__(self):
        self._nodes = {}
        try:
            self._local_cxt = Local.get_context()
        except RuntimeError:
            # Presumably raised when no Local context is active.
            self._local_cxt = None
        if self._local_cxt:
            self._meta_bn = self._local_cxt.meta_bn
        else:
            self._meta_bn = None
        super(_BayesianNet, self).__init__()

    @property
    def nodes(self):
        """
        The dictionary of all named nodes in this :class:`BayesianNet`,
        including all :class:`StochasticTensor` s and named deterministic nodes.

        :return: A dict.
        """
        return self._nodes

    def _get_observation(self, name):
        # Observation for `name` from the enclosing Local context, or None
        # when there is no context or no observation for that name.
        if self._local_cxt:
            ret = self._local_cxt.observations.get(name, None)
            return ret
        return None

    def _check_name_unused(self, name):
        # Node names are unique within a BayesianNet, for stochastic and
        # deterministic nodes alike.
        if name in self._nodes:
            raise ValueError(
                "There exists a node with name '{}' in the {}. Names should "
                "be unique.".format(name, BayesianNet.__name__))

    def stochastic(self, name, dist, **kwargs):
        """
        Add a stochastic node in this :class:`BayesianNet`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param dist: The followed distribution.
        :param kwargs: Optional parameters to specify the sampling behaviors,

            * n_samples: A 0-D `int32` Tensor. Number of samples generated.

        :return: A :class:`StochasticTensor`.
        :raises ValueError: If a node named `name` already exists.
        """
        self._check_name_unused(name)
        # invalidate the log joint cache
        if hasattr(self, "_log_joint_cache"):
            del self._log_joint_cache
        node = StochasticTensor(
            self, name, dist, observation=self._get_observation(name), **kwargs)
        self._nodes[name] = node
        return node

    def deterministic(self, name, input_tensor):
        """
        Add a named deterministic node in this :class:`BayesianNet`.

        :param name: The name of the deterministic node. Must be unique in a
            :class:`BayesianNet`.
        :param input_tensor: A Tensor. The value of the deterministic node.

        :return: A Tensor. The same as `input_tensor`.
        :raises ValueError: If a node named `name` already exists.
        """
        # Without this check an existing node (even a StochasticTensor) would
        # be silently replaced and drop out of the default log joint.
        self._check_name_unused(name)
        input_tensor = tf.convert_to_tensor(input_tensor)
        self._nodes[name] = input_tensor
        return input_tensor

    def _check_name_exist(self, name, only_stochastic=False):
        # Validate a single name: it must be a string naming an existing
        # node (and a stochastic one if `only_stochastic`).
        if not isinstance(name, six.string_types):
            raise TypeError(
                "Expected string in `name_or_names`, got {} of type {}."
                .format(repr(name), type(name)))
        if name not in self._nodes:
            raise ValueError("There isn't a node named '{}' in the {}."
                             .format(name, BayesianNet.__name__))
        elif only_stochastic and not isinstance(
                self._nodes[name], StochasticTensor):
            raise ValueError("Node '{}' is deterministic (input or output)."
                             .format(name))
        return name

    def _check_names_exist(self, name_or_names, only_stochastic=False):
        """
        Check if there are ``StochasticTensor`` s with `name_or_names` in the
        net.

        :param name_or_names: A string or a tuple(list) of strings. Names of
            ``StochasticTensor`` s in the net.
        :param only_stochastic: A bool. Whether to check only in stochastic
            nodes. Default is `False`.

        :return: The validated name, or a tuple of the validated names.
        """
        if isinstance(name_or_names, six.string_types):
            names = (name_or_names,)
        else:
            name_or_names = tuple(name_or_names)
            names = name_or_names
        for name in names:
            _ = self._check_name_exist(name, only_stochastic=only_stochastic)
        return name_or_names

    def get(self, name_or_names):
        """
        Get one or several nodes by name. For a single node, one can also use
        dictionary-like ``bn[name]`` to get the node.

        :param name_or_names: A string or a tuple(list) of strings.
        :return: A Tensor/:class:`StochasticTensor` or a list of
            Tensor/:class:`StochasticTensor` s.
        """
        name_or_names = self._check_names_exist(name_or_names)
        if isinstance(name_or_names, tuple):
            return [self._nodes[name] for name in name_or_names]
        else:
            return self._nodes[name_or_names]

    def cond_log_prob(self, name_or_names):
        """
        The conditional log probabilities of stochastic nodes,
        evaluated at their current values (given by
        :attr:`StochasticTensor.tensor`).

        :param name_or_names: A string or a list of strings. Name(s) of the
            stochastic nodes.
        :return: A Tensor or a list of Tensors.
        """
        name_or_names = self._check_names_exist(name_or_names,
                                                only_stochastic=True)
        if isinstance(name_or_names, tuple):
            return [self._nodes[name].cond_log_p for name in name_or_names]
        else:
            return self._nodes[name_or_names].cond_log_p

    def _log_joint(self):
        # Use the meta Bayesian net's custom `log_joint` callable if one is
        # set; otherwise sum the conditional log probabilities of all
        # stochastic nodes.
        if (self._meta_bn is None) or (self._meta_bn.log_joint is None):
            ret = sum(node.cond_log_p for node in six.itervalues(self._nodes)
                      if isinstance(node, StochasticTensor))
        elif callable(self._meta_bn.log_joint):
            ret = self._meta_bn.log_joint(self)
        else:
            raise TypeError(
                "{}.log_joint is set to a non-callable instance: {}"
                .format(self._meta_bn.__class__.__name__,
                        repr(self._meta_bn.log_joint)))
        return ret

    def log_joint(self):
        """
        The default log joint probability of this :class:`BayesianNet`.
        It works by summing over all the conditional log probabilities of
        stochastic nodes evaluated at their current values (samples or
        observations).

        The result is cached until a new stochastic node is added.

        :return: A Tensor.
        """
        if not hasattr(self, "_log_joint_cache"):
            self._log_joint_cache = self._log_joint()
        return self._log_joint_cache

    def __getitem__(self, name):
        name = self._check_name_exist(name)
        return self._nodes[name]

    def __setitem__(self, name, node):
        raise TypeError(
            "{} instance does not support replacement of the existing node. "
            "To achieve this, pass observations of certain nodes when "
            "calling {}.{}".format(
                BayesianNet.__name__, MetaBayesianNet.__name__,
                MetaBayesianNet.observe.__name__))


class BayesianNet(_BayesianNet, Context):
    """
    The :class:`BayesianNet` class provides a convenient way to construct
    Bayesian networks, i.e., directed graphical models.

    To start, we create a :class:`BayesianNet` instance::

        bn = zs.BayesianNet()

    A :class:`BayesianNet` keeps two kinds of nodes:

    * deterministic nodes: they are just Tensors, usually the outputs of
      Tensorflow operations.
    * stochastic nodes: they are random variables in graphical models, and can
      be constructed like

    ::

        w = bn.normal("w", 0., std=alpha)

    Here ``w`` is a :class:`StochasticTensor` that follows the
    :class:`~zhusuan.distributions.univariate.Normal` distribution. For any
    distribution available in :mod:`zhusuan.distributions`, we can find
    a method of :class:`BayesianNet` for creating the corresponding stochastic
    node. If you define your own distribution class, then there is a
    general method :meth:`stochastic` for doing this::

        dist = CustomizedDistribution()
        w = bn.stochastic("w", dist)

    To observe any stochastic nodes in the network, pass a dictionary mapping
    of ``(name, Tensor)`` pairs when constructing :class:`BayesianNet`.
    This will assign observed values to corresponding
    :class:`StochasticTensor` s. For example::

        bn = zs.BayesianNet(observed={"w": w_obs})

    will set ``w`` to be observed.

    .. note::

        The observation passed must have the same type and shape as the
        :class:`StochasticTensor`.

    A useful case is that we often need to pass different observations more
    than once into the Bayesian network, for which we provide
    :func:`~zhusuan.framework.meta_bn.meta_bayesian_net` decorator and another
    abstract class :class:`~zhusuan.framework.meta_bn.MetaBayesianNet`.

    .. seealso::

        For more details and examples, please refer to
        :doc:`/tutorials/concepts`.

    :param observed: A dictionary of (string, Tensor) pairs, which maps from
        names of stochastic nodes to their observed values.
    """

    def __init__(self, observed=None):
        # Observations supplied at construction time. StochasticTensor falls
        # back to this dict when no explicit observation is given (kept to
        # support deprecated features).
        self._observed = observed or {}
        super(BayesianNet, self).__init__()

    def normal(self, name, mean=0., _sentinel=None, std=None, logstd=None,
               group_ndims=0, n_samples=None, is_reparameterized=True,
               check_numerics=False, **kwargs):
        """
        Create a stochastic node following the Normal distribution and add
        it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Normal`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Normal(
            mean, _sentinel=_sentinel, std=std, logstd=logstd,
            group_ndims=group_ndims, is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def fold_normal(self, name, mean=0., _sentinel=None, std=None,
                    logstd=None, n_samples=None, group_ndims=0,
                    is_reparameterized=True, check_numerics=False, **kwargs):
        """
        Create a stochastic node following the FoldNormal distribution and
        add it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.FoldNormal`; refer to it
        for details. Extra ``kwargs`` are forwarded both to the distribution
        and to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.FoldNormal(
            mean, _sentinel=_sentinel, std=std, logstd=logstd,
            group_ndims=group_ndims, is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def bernoulli(self, name, logits, n_samples=None, group_ndims=0,
                  dtype=tf.int32, **kwargs):
        """
        Create a stochastic node following the Bernoulli distribution and add
        it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Bernoulli`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Bernoulli(
            logits, group_ndims=group_ndims, dtype=dtype, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def categorical(self, name, logits, n_samples=None, group_ndims=0,
                    dtype=tf.int32, **kwargs):
        """
        Create a stochastic node following the Categorical distribution and
        add it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Categorical`; refer to it
        for details. Extra ``kwargs`` are forwarded both to the distribution
        and to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Categorical(
            logits, group_ndims=group_ndims, dtype=dtype, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    # Alias of :meth:`categorical`.
    discrete = categorical

    def uniform(self, name, minval=0., maxval=1., n_samples=None,
                group_ndims=0, is_reparameterized=True, check_numerics=False,
                **kwargs):
        """
        Create a stochastic node following the Uniform distribution and add
        it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Uniform`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Uniform(
            minval, maxval, group_ndims=group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def gamma(self, name, alpha, beta, n_samples=None, group_ndims=0,
              check_numerics=False, **kwargs):
        """
        Create a stochastic node following the Gamma distribution and add it
        to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Gamma`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Gamma(
            alpha, beta, group_ndims=group_ndims,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def beta(self, name, alpha, beta, n_samples=None, group_ndims=0,
             check_numerics=False, **kwargs):
        """
        Create a stochastic node following the Beta distribution and add it
        to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Beta`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Beta(
            alpha, beta, group_ndims=group_ndims,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def poisson(self, name, rate, n_samples=None, group_ndims=0,
                dtype=tf.int32, check_numerics=False, **kwargs):
        """
        Create a stochastic node following the Poisson distribution and add
        it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Poisson`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Poisson(
            rate, group_ndims=group_ndims, dtype=dtype,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def binomial(self, name, logits, n_experiments, n_samples=None,
                 group_ndims=0, dtype=tf.int32, check_numerics=False,
                 **kwargs):
        """
        Create a stochastic node following the Binomial distribution and add
        it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.univariate.Binomial`; refer to it for
        details. Extra ``kwargs`` are forwarded both to the distribution and
        to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Binomial(
            logits, n_experiments, group_ndims=group_ndims, dtype=dtype,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def multivariate_normal_cholesky(self, name, mean, cov_tril,
                                     n_samples=None, group_ndims=0,
                                     is_reparameterized=True,
                                     check_numerics=False, **kwargs):
        """
        Create a stochastic node following the MultivariateNormalCholesky
        distribution and add it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.multivariate.MultivariateNormalCholesky`;
        refer to it for details. Extra ``kwargs`` are forwarded both to the
        distribution and to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        # `group_ndims` is passed positionally, as the third argument.
        node_dist = distributions.MultivariateNormalCholesky(
            mean, cov_tril, group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def matrix_variate_normal_cholesky(self, name, mean, u_tril, v_tril,
                                       n_samples=None, group_ndims=0,
                                       is_reparameterized=True,
                                       check_numerics=False, **kwargs):
        """
        Create a stochastic node following the MatrixVariateNormalCholesky
        distribution and add it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.multivariate.MatrixVariateNormalCholesky`;
        refer to it for details. Extra ``kwargs`` are forwarded both to the
        distribution and to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        # `group_ndims` is passed positionally, as the fourth argument.
        node_dist = distributions.MatrixVariateNormalCholesky(
            mean, u_tril, v_tril, group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def multinomial(self, name, logits, n_experiments, normalize_logits=True,
                    n_samples=None, group_ndims=0, dtype=tf.int32, **kwargs):
        """
        Create a stochastic node following the Multinomial distribution and
        add it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.
        :param n_samples: A 0-D `int32` Tensor or None. Number of samples the
            node generates when it is not observed.

        The remaining arguments are those of
        :class:`~zhusuan.distributions.multivariate.Multinomial`; refer to it
        for details. Extra ``kwargs`` are forwarded both to the distribution
        and to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Multinomial(
            logits, n_experiments, normalize_logits=normalize_logits,
            group_ndims=group_ndims, dtype=dtype, **kwargs)
        return self.stochastic(name, node_dist, n_samples=n_samples,
                               **kwargs)

    def unnormalized_multinomial(self, name, logits, normalize_logits=True,
                                 group_ndims=0, dtype=tf.int32, **kwargs):
        """
        Create a stochastic node following the UnnormalizedMultinomial
        distribution and add it to this :class:`BayesianNet`.

        :param name: A string. Name of the new node; it must not clash with
            any existing node in this :class:`BayesianNet`.

        Unlike the other node constructors, this method has no explicit
        ``n_samples`` argument. The remaining arguments are those of
        :class:`~zhusuan.distributions.multivariate.UnnormalizedMultinomial`;
        refer to it for details. Extra ``kwargs`` are forwarded both to the
        distribution and to :meth:`stochastic`.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.UnnormalizedMultinomial(
            logits, normalize_logits=normalize_logits,
            group_ndims=group_ndims, dtype=dtype, **kwargs)
        return self.stochastic(name, node_dist, **kwargs)

    # Alias of :meth:`unnormalized_multinomial`.
    bag_of_categoricals = unnormalized_multinomial

    def onehot_categorical(self, name, logits, n_samples=None, group_ndims=0,
                           dtype=tf.int32, **kwargs):
        """
        Register a stochastic node that follows the OnehotCategorical
        distribution in this :class:`BayesianNet`.

        The distribution is built from ``logits``, ``group_ndims`` and
        ``dtype``. The node is then created through :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.multivariate.OnehotCategorical`
        for the meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.OnehotCategorical(
            logits, group_ndims=group_ndims, dtype=dtype, **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    # Alternative name for :meth:`onehot_categorical`.
    onehot_discrete = onehot_categorical

    def dirichlet(self, name, alpha, n_samples=None, group_ndims=0,
                  check_numerics=False, **kwargs):
        """
        Register a stochastic node that follows the Dirichlet distribution
        in this :class:`BayesianNet`.

        The distribution is built from ``alpha``, ``group_ndims`` and
        ``check_numerics``. The node is then created through
        :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.multivariate.Dirichlet` for the
        meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Dirichlet(
            alpha, group_ndims=group_ndims, check_numerics=check_numerics,
            **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    def inverse_gamma(self, name, alpha, beta, n_samples=None,
                      group_ndims=0, check_numerics=False, **kwargs):
        """
        Register a stochastic node that follows the InverseGamma
        distribution in this :class:`BayesianNet`.

        The distribution is built from ``alpha``, ``beta``, ``group_ndims``
        and ``check_numerics``. The node is then created through
        :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.univariate.InverseGamma` for the
        meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.InverseGamma(
            alpha, beta, group_ndims=group_ndims,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    def laplace(self, name, loc, scale, n_samples=None, group_ndims=0,
                is_reparameterized=True, check_numerics=False, **kwargs):
        """
        Register a stochastic node that follows the Laplace distribution in
        this :class:`BayesianNet`.

        The distribution is built from ``loc``, ``scale``, ``group_ndims``,
        ``is_reparameterized`` and ``check_numerics``. The node is then
        created through :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.univariate.Laplace` for the
        meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Laplace(
            loc, scale, group_ndims=group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    def bin_concrete(self, name, temperature, logits, n_samples=None,
                     group_ndims=0, is_reparameterized=True,
                     check_numerics=False, **kwargs):
        """
        Register a stochastic node that follows the BinConcrete
        distribution in this :class:`BayesianNet`.

        The distribution is built from ``temperature``, ``logits``,
        ``group_ndims``, ``is_reparameterized`` and ``check_numerics``. The
        node is then created through :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.univariate.BinConcrete` for the
        meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.BinConcrete(
            temperature, logits, group_ndims=group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    # Alternative name for :meth:`bin_concrete`.
    bin_gumbel_softmax = bin_concrete

    def exp_concrete(self, name, temperature, logits, n_samples=None,
                     group_ndims=0, is_reparameterized=True,
                     check_numerics=False, **kwargs):
        """
        Register a stochastic node that follows the ExpConcrete
        distribution in this :class:`BayesianNet`.

        The distribution is built from ``temperature``, ``logits``,
        ``group_ndims``, ``is_reparameterized`` and ``check_numerics``. The
        node is then created through :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.multivariate.ExpConcrete` for the
        meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.ExpConcrete(
            temperature, logits, group_ndims=group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    # Alternative name for :meth:`exp_concrete`.
    exp_gumbel_softmax = exp_concrete

    def concrete(self, name, temperature, logits, n_samples=None,
                 group_ndims=0, is_reparameterized=True,
                 check_numerics=False, **kwargs):
        """
        Register a stochastic node that follows the Concrete distribution
        in this :class:`BayesianNet`.

        The distribution is built from ``temperature``, ``logits``,
        ``group_ndims``, ``is_reparameterized`` and ``check_numerics``. The
        node is then created through :meth:`stochastic`.

        :param name: The name of the stochastic node. Must be unique in a
            :class:`BayesianNet`.
        :param n_samples: Passed on unchanged to :meth:`stochastic`.
        :param kwargs: Extra keyword arguments. They are forwarded to *both*
            the distribution constructor and :meth:`stochastic`.

        See :class:`~zhusuan.distributions.multivariate.Concrete` for the
        meaning of the remaining arguments.

        :return: A :class:`StochasticTensor` instance.
        """
        node_dist = distributions.Concrete(
            temperature, logits, group_ndims=group_ndims,
            is_reparameterized=is_reparameterized,
            check_numerics=check_numerics, **kwargs)
        return self.stochastic(
            name, node_dist, n_samples=n_samples, **kwargs)

    # Alternative name for :meth:`concrete`.
    gumbel_softmax = concrete

    # Below are deprecated features:

    def __enter__(self):
        """
        Enter the context-manager form of :class:`BayesianNet`.

        .. note::

            Using a :class:`BayesianNet` as a context has been deprecated
            in 0.4.

        Emits a :class:`FutureWarning`, then defers to the next
        ``__enter__`` in the MRO and returns its result.
        """
        # stacklevel=2 attributes the warning to the caller's ``with``
        # statement instead of to this line inside the library.
        warnings.warn(
            "Using `BayesianNet` as contexts has been deprecated in 0.4. "
            "Please see tutorials/concepts.rst for the suggested way of "
            "model construction.", FutureWarning, stacklevel=2)
        return super(BayesianNet, self).__enter__()

    def outputs(self, name_or_names):
        """
        Return the ``.tensor`` attribute of one or more nodes.

        .. note::

            Deprecated in 0.4, will be removed in 0.4.1. Use ``get()``
            instead.

        :param name_or_names: A node name or several node names. The value
            is passed through ``_check_names_exist`` before lookup.
        :return: If the validated value is a tuple, a list holding the
            ``.tensor`` of each named node, in order. Otherwise the
            ``.tensor`` of the single named node.
        """
        # stacklevel=2 makes the warning point at whoever called
        # ``outputs()`` rather than at this line inside the library.
        warnings.warn(
            "BayesianNet: `outputs()` has been deprecated in 0.4 and will "
            "be removed in 0.4.1, use `get()` instead.", FutureWarning,
            stacklevel=2)
        name_or_names = self._check_names_exist(name_or_names)
        if isinstance(name_or_names, tuple):
            return [self._nodes[name].tensor for name in name_or_names]
        else:
            return self._nodes[name_or_names].tensor

    def local_log_prob(self, name_or_names):
        """
        Deprecated alias that delegates to :meth:`cond_log_prob`.

        .. note::

            Deprecated in 0.4, will be removed in 0.4.1. Use
            ``cond_log_prob()`` instead.

        :param name_or_names: Forwarded unchanged to :meth:`cond_log_prob`.
        :return: Whatever :meth:`cond_log_prob` returns for the same
            argument.
        """
        # stacklevel=2 makes the warning point at whoever called
        # ``local_log_prob()`` rather than at this line inside the library.
        warnings.warn(
            "BayesianNet: `local_log_prob()` has been deprecated in 0.4 "
            "and will be removed in 0.4.1, use `cond_log_prob()` instead.",
            FutureWarning, stacklevel=2)
        return self.cond_log_prob(name_or_names)

    def query(self, name_or_names, outputs=False, local_log_prob=False):
        """
        Fetch node outputs and/or conditional log-probabilities at once.

        .. note::

            Deprecated in 0.4, will be removed in 0.4.1. Use ``get()`` and
            ``cond_log_prob()`` instead.

        The names are validated with ``_check_names_exist`` first. Because
        this method is implemented on top of the deprecated
        :meth:`outputs` and :meth:`local_log_prob`, each selected option
        also emits that method's own :class:`FutureWarning`.

        :param name_or_names: A node name or several node names.
        :param outputs: If true, include the result of :meth:`outputs`.
        :param local_log_prob: If true, include the result of
            :meth:`local_log_prob`.
        :return: If the validated names form a tuple, a list with one tuple
            per name, pairing that name's selected results (outputs first,
            then log-prob). Otherwise a tuple of the selected results in the
            same order.
        :raises ValueError: If neither ``outputs`` nor ``local_log_prob``
            is selected. This check runs after name validation.
        """
        # stacklevel=2 makes the warning point at whoever called
        # ``query()`` rather than at this line inside the library.
        warnings.warn(
            "BayesianNet: `query()` has been deprecated in 0.4 "
            "and will be removed in 0.4.1, use `get()` and "
            "`cond_log_prob()` instead.", FutureWarning, stacklevel=2)
        name_or_names = self._check_names_exist(name_or_names)
        ret = []
        if outputs:
            ret.append(self.outputs(name_or_names))
        if local_log_prob:
            ret.append(self.local_log_prob(name_or_names))
        if len(ret) == 0:
            raise ValueError("No query options are selected.")
        elif isinstance(name_or_names, tuple):
            # Transpose [per-option lists] into [per-name tuples].
            return list(zip(*ret))
        else:
            return tuple(ret)
